"""Interface for Unitree wireless handle. The command structure contains 40 bytes in total. For details,
please check `third_party/unitree_legged_sdk/include/unitree_legged_sdk/unitree_joystick.h`
"""

import time
import enum
import struct
from typing import List
from dataclasses import dataclass


class WirelessMode(enum.Enum):
    """The mode of the remote controller (wireless handle).

    Within this file only INVALID, STOP, START and DISCONNECT are produced
    (by `connect_status_parse`); WALK and STAND are not referenced by the
    code visible here.
    """
    INVALID = 0     # header accepted, but no recognised button combination
    STOP = 1        # R1 + X pressed
    WALK = 2        # NOTE(review): not produced in this file - confirm usage elsewhere
    STAND = 3       # NOTE(review): not produced in this file - confirm usage elsewhere
    START = 4       # R1 + A pressed
    DISCONNECT = 5  # message header check failed


@dataclass
class MsgField3:
    """Bit masks for the buttons reported in field 3 of the message.

    Field 3 is the third byte of the message (`msg[2]`); each constant below
    selects a single bit of that byte. Because every field has a default
    value, the masks are also readable directly on the class (e.g.
    `MsgField3.BNT_R1`) without creating an instance.

      BNT_R1:      Button R1
      BNT_L1:      Button L1
      BNT_START:   Button START
      BNT_SELECT:  Button SELECT
      BNT_R2:      Button R2
      BNT_L2:      Button L2
    """

    # Field 3 (msg[2]): one bit per button, lowest bit first
    BNT_R1: int = 0x01
    BNT_L1: int = 0x02
    BNT_START: int = 0x04
    BNT_SELECT: int = 0x08
    BNT_R2: int = 0x10
    BNT_L2: int = 0x20


@dataclass
class MsgField4:
    """Bit masks for the buttons reported in field 4 of the message.

    Field 4 is the fourth byte of the message (`msg[3]`); each constant below
    selects a single bit of that byte. Because every field has a default
    value, the masks are also readable directly on the class (e.g.
    `MsgField4.BTN_A`) without creating an instance.

      BTN_A:      Button A
      BTN_B:      Button B
      BTN_X:      Button X
      BTN_Y:      Button Y
      BTN_UP:     Button UP
      BTN_RIGHT:  Button RIGHT
      BTN_DOWN:   Button DOWN
      BTN_LEFT:   Button LEFT
    """

    # Field 4 (msg[3]): one bit per button, lowest bit first
    BTN_A: int = 0x01
    BTN_B: int = 0x02
    BTN_X: int = 0x04
    BTN_Y: int = 0x08
    BTN_UP: int = 0x10
    BTN_RIGHT: int = 0x20
    BTN_DOWN: int = 0x40
    BTN_LEFT: int = 0x80


def bytearray_to_float(byte_array):
    """Decode a 4-byte buffer into a single-precision float.

    The bytes are interpreted with the host's native byte order
    (struct format ``'f'``).

    Args:
        byte_array: A bytes-like object holding exactly four bytes.

    Returns:
        The decoded value as a Python float.

    Raises:
        ValueError: If the buffer does not contain exactly four bytes.
    """
    if len(byte_array) == 4:
        # struct.unpack always yields a tuple; 'f' produces one element
        return struct.unpack('f', byte_array)[0]

    raise ValueError("The input bytearray must contain exactly 4 bytes.")


def connect_status_parse(msg: List[int]) -> WirelessMode:
    """Derive the controller mode from the header and button bytes.

    Args:
        msg: The raw message as a sequence of byte values. Only the first
            four entries are inspected: `msg[0]` and `msg[1]` are the header,
            `msg[2]` is button field 3 and `msg[3]` is button field 4.

    Returns:
        `WirelessMode.DISCONNECT` when the header is not (85, 81),
        `WirelessMode.STOP` for R1 + X, `WirelessMode.START` for R1 + A,
        otherwise `WirelessMode.INVALID`.
    """
    # Message header: BOTH bytes must match. The header is invalid as soon as
    # either byte differs, hence `or` (with `and`, a frame with a single
    # corrupted header byte would be accepted as valid).
    if msg[0] != 85 or msg[1] != 81:
        return WirelessMode.DISCONNECT

    # The combinations below are matched by equality, so they only trigger
    # when exactly these buttons (and no others in the same field) are set.

    # Press R1 + X -> STOP
    if msg[2] == MsgField3.BNT_R1 and msg[3] == MsgField4.BTN_X:
        return WirelessMode.STOP

    # Press R1 + A -> START
    if msg[2] == MsgField3.BNT_R1 and msg[3] == MsgField4.BTN_A:
        return WirelessMode.START

    return WirelessMode.INVALID


def rc_cmd_parse(msg):
    """Parse a 40-byte wireless-handle message into stick values and mode.

    Args:
        msg: The raw message as a sequence of 40 byte values (e.g. a list of
            ints in the range 0-255).

    Returns:
        A tuple `(lx, ly, rx, ry, connect_status)` where the first four
        entries are floats decoded from the message and `connect_status` is
        the value returned by `connect_status_parse`.

    Raises:
        ValueError: If `msg` does not contain exactly 40 entries.
    """
    # length of message must be 40
    if len(msg) != 40:
        raise ValueError("The input message must contain exactly 40 bytes.")

    connect_status = connect_status_parse(msg=msg)

    # Each stick value is a 4-byte float. Note that the layout is not in
    # lx/ly/rx/ry order: ly lives at bytes 20-23, and bytes 16-19 are not
    # read here.
    lx = bytearray_to_float(byte_array=bytearray(msg[4:8]))
    ly = bytearray_to_float(byte_array=bytearray(msg[20:24]))
    rx = bytearray_to_float(byte_array=bytearray(msg[8:12]))
    ry = bytearray_to_float(byte_array=bytearray(msg[12:16]))

    return lx, ly, rx, ry, connect_status


if __name__ == "__main__":
    # Manual smoke test: parse one sample 40-byte frame and report how long
    # the parse took (wall-clock seconds).
    cmd = [85, 81, 32, 0, 207, 158, 34, 58, 150, 89, 35, 186, 127, 90, 66, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 5, 10]

    started_at = time.time()
    lx, ly, rx, ry, status = rc_cmd_parse(msg=cmd)
    elapsed = time.time() - started_at

    print(f"time is: {elapsed}")